from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import openai
from collections import defaultdict
from datetime import datetime

class OpenBookQAReasoner(BaseReasoner):
    """LLM-driven multi-step reasoner for OpenBookQA multiple-choice questions.

    ``execute_workflow`` builds a small reasoning tree: a root node for the
    question, LLM-extracted constraints, several LLM-proposed solution
    methods (one child node each, up to ``config.beam_width``), and finally
    an LLM solve of the highest-scoring method node.
    """

    # Option letters assigned, in order, to ``problem["choices"]["text"]``.
    _OPTION_LABELS = ("A", "B", "C", "D")

    # Persona preamble shared verbatim by every prompt below. The two trailing
    # spaces on most lines are markdown hard line breaks and are intentional.
    _PERSONA = (
        "You are a top expert in formal logic, critical thinking, and argument analysis.  \n"
        "You are precise, rational, and skeptical.  \n"
        "You always examine each statement carefully, identify premises and conclusions, "
        "and evaluate logical validity step by step.  \n"
        "You avoid unwarranted assumptions, think in terms of logical consequences, "
        "and eliminate invalid options with sound reasoning.  \n"
        "You aim to reach conclusions based only on evidence and logic.  \n"
        "You THINK SLOWLY, CAREFULLY, AND LOGICALLY.\n"
    )

    # Answer-extraction patterns, tried in order of decreasing reliability.
    _BOXED_RE = re.compile(r"\\boxed\{\s*([A-D])\s*\}")
    # The trailing word boundary stops "Answer: because ..." from yielding "B".
    _ANSWER_RE = re.compile(r"answer\s*:\s*[*(\[]*\s*([A-D])\b", re.IGNORECASE)
    # Weak last resort: a standalone capital A-D not followed by ".<word>".
    _OPTION_RE = re.compile(r"\b([A-D])\b(?!\.\w)")

    # A reply wrapped entirely in a markdown code fence, e.g. ```json ... ```.
    _CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)

    # First number inside a non-numeric score such as "85" or "85/100".
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

    def __init__(self):
        super().__init__("OpenBookQA")
        self.config.dataset_path = "datasets/OpenBookQA.json"

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load the slice ``[start_idx:end_idx]`` of the OpenBookQA dataset.

        Best-effort: any failure (missing file, bad JSON, non-list data) is
        printed and an empty list is returned.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Run the full reasoning workflow for one OpenBookQA problem.

        Args:
            problem: Dataset record; reads ``question_stem`` and
                ``choices["text"]`` (the first four texts are labelled A-D).

        Returns:
            On success: ``status="success"``, ``final_answer`` (a letter, or
            ``"X"`` if no answer could be extracted), ``nodes``, ``logs`` and
            ``token_usage``. On any exception: ``status="error"``,
            ``message`` and ``logs``.
        """
        try:
            question = problem["question_stem"]
            texts = problem["choices"]["text"]
            if len(texts) < len(self._OPTION_LABELS):
                raise ValueError(
                    f"Expected {len(self._OPTION_LABELS)} choices, got {len(texts)}"
                )
            options = dict(zip(self._OPTION_LABELS, texts))

            # Step 1: Create root node
            root = self._create_node(
                question=question,
                options=options,
                constraints={},
                path=[],
                method={"description": "Original problem"}
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step 2: Extract constraints
            constraints = await self._extract_constraints(question, options)
            root.constraints = constraints
            self._log_step("step2", root.node_id, {"constraints": constraints})

            # Step 3: Explore solution methods
            methods = await self._explore_solutions(question, options)
            if not methods:
                # _explore_solutions returns [] after exhausting its retries;
                # solve with a generic approach instead of failing on max([]).
                methods = [self._fallback_method()]
            self._log_step("step3", root.node_id, {"methods": methods})

            # Step 4: Create method nodes. Methods arrive sorted best-first; the
            # `or` keeps at least one node when beam_width is 0 or negative.
            selected_methods = methods[:self.config.beam_width] or methods[:1]
            method_nodes = []
            for method in selected_methods:
                node = self._create_node(
                    path=[root.node_id],
                    question=question,
                    options=options,
                    method=method,
                    constraints=root.constraints,
                    score=method.get("score", 0),
                    parent_id=root.node_id
                )
                root.children.append(node.node_id)
                method_nodes.append(node)
                self._log_step("step4", node.node_id, {"method": method})

            # Step 5: Solve the best method node directly
            best_method_node = max(method_nodes, key=lambda node: node.score)
            solution = await self._solve_node(best_method_node.node_id)
            self._log_step("step5", best_method_node.node_id, {"solution": solution})

            final_answer = solution["answer"] if solution else "X"
            self._log_step("step6", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
                "token_usage": self.llm.token_counts
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }

    async def _extract_constraints(self, question: str, options: Dict[str, str]) -> Dict[str, Any]:
        """Ask the LLM for the question's constraints as a JSON object.

        Retries up to ``config.max_retries`` times on LLM or parse failures
        and returns an empty constraint dict if every attempt fails.
        """
        prompt = f"""{self._PERSONA}Analyze this question and extract key constraints:

Question: {question}
Options:
{self._format_options(options)}

Identify:
1. Explicit constraints (directly stated)
2. Implicit constraints (logical implications)
3. Key terms and their relationships
4. Spatial/temporal relationships if present
5. Any conditional statements

Output JSON format:
{{
    "explicit": ["list", "of", "constraints"],
    "implicit": ["list", "of", "constraints"],
    "key_terms": ["term1", "term2"],
    "notes": "Analysis summary"
}}"""

        for attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                data = json.loads(self._strip_code_fence(response))
            except Exception as e:
                # `Exception`, not a bare except: asyncio.CancelledError and
                # KeyboardInterrupt must still propagate instead of retrying.
                print(f"Constraint extraction attempt {attempt + 1} failed: {str(e)}")
                continue
            if isinstance(data, dict):
                return data
            print(f"Constraint extraction attempt {attempt + 1} returned non-object JSON")

        return {
            "explicit": [],
            "implicit": [],
            "key_terms": [],
            "notes": "Failed to extract constraints"
        }

    async def _explore_solutions(self, question: str, options: Dict[str, str]) -> List[Dict]:
        """Step 3: ask the LLM for diverse solution methods, sorted best-first.

        Returns:
            Validated method dicts (each with ``description``, ``steps``,
            numeric ``score`` and ``score_reason``) in descending score
            order, or ``[]`` if every retry produced unusable output.
        """
        prompt = f"""{self._PERSONA}Generate 3 distinct solution approaches for this question:

Question: {question}
Options:
{self._format_options(options)}

For each approach, provide:
- Clear description of the reasoning strategy
- Key steps to implement the approach
- Confidence score (0-100) based on:
  * Logical soundness
  * Coverage of options
  * Appropriate use of deductive/inductive reasoning
  * Clarity of reasoning steps

Output JSON format:
{{
    "methods": [
        {{
            "description": "Approach description",
            "steps": ["step1", "step2"],
            "score": 0-100,
            "score_reason": "Scoring justification"
        }}
    ]
}}"""

        for attempt in range(self.config.max_retries):
            # Reset per attempt so the failure log never shows a stale reply
            # (or hits an unbound name if generate() itself raised).
            response = None
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                methods = self._parse_methods(response)
                return sorted(methods, key=lambda method: -method["score"])
            except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt == self.config.max_retries - 1:
                    print(f"Final failed response: {response}")

        return []  # Fallback if all retries fail

    def _parse_methods(self, response: Optional[str]) -> List[Dict[str, Any]]:
        """Parse and validate the LLM's method list; normalise scores to numbers.

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
            ValueError: If the JSON does not have the expected structure.
        """
        data = json.loads(self._strip_code_fence(response))

        if not isinstance(data, dict) or "methods" not in data:
            raise ValueError("Invalid structure: missing 'methods' key")

        methods = data["methods"]
        if not isinstance(methods, list):
            raise ValueError("Invalid structure: 'methods' must be a list")
        if len(methods) < 2:
            raise ValueError(f"Expected at least 2 methods, got {len(methods)}")

        required_keys = {"description", "steps", "score", "score_reason"}
        for method in methods:
            if not isinstance(method, dict) or not all(k in method for k in required_keys):
                raise ValueError("Missing required keys in method")
            if not isinstance(method["steps"], list):
                raise ValueError("Steps must be a list")
            # The prompt template shows `"score": 0-100`, so models sometimes
            # reply with strings; ranking needs a number.
            method["score"] = self._coerce_score(method["score"])
        return methods

    async def _solve_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Step 5: have the LLM solve one node using that node's method.

        Returns:
            ``{"node_id", "response", "answer"}`` and marks the node
            ``"solved"`` when an answer letter is found; ``None`` otherwise.
        """
        node = self.nodes[node_id]
        method = node.method or {}
        description = method.get("description", "")
        steps = ", ".join(str(step) for step in method.get("steps", []))

        # Build context prompt
        context = (
            f"Question: {node.question}\nOptions:\n"
            f"{self._format_options(node.options)}\n"
            f"\nSolution Approach: {description}\n"
            f"Constraints: {json.dumps(node.constraints, indent=2)}\n"
        )

        # NOTE: `\\boxed` must stay escaped -- in a non-raw f-string `\b` is a
        # backspace character and the model would never see the instruction.
        prompt = f"""{self._PERSONA}Solve this question using the specified approach:

{context}

Reasoning Steps:
1. Strictly follow the provided approach: {description}
2. Execute each step: {steps}
3. Consider all constraints
4. Evaluate each option systematically
5. Provide clear justification for inclusion/exclusion
6. Select the best answer

Output Requirements:
- End your response with: "Final Answer: [OPTION]"
- Use \\boxed{{[OPTION]}} to denote your answer
- Your answer must be A, B, C, or D
"""

        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)

        if answer:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer
            }
        return None

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Build a JSON-serialisable record of one problem's run and its grading.

        Args:
            result: Return value of ``execute_workflow``.
            problem: The original dataset record (``answerKey`` is read).
        """
        serialized_nodes = {
            node_id: {
                "node_id": node.node_id,
                "question": node.question,
                "options": node.options,
                "method": node.method,
                "constraints": node.constraints,
                "answer": node.answer,
                "state": node.state,
                "score": node.score
            }
            for node_id, node in self.nodes.items()
        }

        selected_answer = result.get("final_answer", "X")
        correct_answer = self._correct_answer(problem)
        is_correct = self.verify_answer(problem, selected_answer)
        verification = {
            "is_correct": is_correct,
            "correct_answer": correct_answer,
            "given_answer": selected_answer
        }
        return {
            "problem": problem,
            "result": {
                "final_answer": selected_answer,
                "correct_answer": correct_answer,
                "is_correct": is_correct,
                "nodes": serialized_nodes,
                "token_usage": result.get("token_usage", [0, 0])
            },
            "verification": verification
        }

    def _extract_answer(self, text: Optional[str]) -> Optional[str]:
        """Extract the chosen option letter (uppercase) from an LLM response.

        Tries ``\\boxed{X}``, then ``Answer: X``, then a bare option letter.
        For each pattern the LAST match wins, because the prompt asks the
        model to end with its final answer; earlier matches are usually
        intermediate discussion of individual options.
        """
        if not text:
            return None
        for pattern in (self._BOXED_RE, self._ANSWER_RE, self._OPTION_RE):
            matches = pattern.findall(text)
            if matches:
                return matches[-1].upper()
        return None

    def verify_answer(self, problem: Dict[str, Any], selected_answer: str) -> bool:
        """Return True iff *selected_answer* matches the problem's ``answerKey``.

        A missing answer key or a ``None`` answer is never counted as correct.
        """
        correct_answer = self._correct_answer(problem)
        if not correct_answer or selected_answer is None:
            return False
        return str(selected_answer).strip().upper() == correct_answer

    @staticmethod
    def _correct_answer(problem: Dict[str, Any]) -> str:
        """Return the normalised (stripped, uppercase) ``answerKey``, or ``""``."""
        return str(problem.get("answerKey") or "").strip().upper()

    @staticmethod
    def _format_options(options: Dict[str, str]) -> str:
        """Render options as ``"A. text"`` lines joined by newlines."""
        return "\n".join(f"{label}. {text}" for label, text in options.items())

    @classmethod
    def _strip_code_fence(cls, text: Optional[str]) -> str:
        """Return *text* without a surrounding markdown code fence, if any."""
        text = "" if text is None else str(text).strip()
        match = cls._CODE_FENCE_RE.match(text)
        if match:
            return match.group(1)
        if text.startswith("```"):
            # Unterminated fence: drop only the opening marker and language tag.
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        return text.strip()

    @classmethod
    def _coerce_score(cls, value: Any) -> Union[int, float]:
        """Return *value* as a rankable number; 0 when no number can be parsed."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value if value == value else 0  # NaN would break sorting
        match = cls._NUMBER_RE.search(str(value))
        return float(match.group()) if match else 0

    @staticmethod
    def _fallback_method() -> Dict[str, Any]:
        """Return a fresh generic solution method used when none were generated."""
        return {
            "description": "Evaluate each option directly against the question and the extracted constraints",
            "steps": [
                "Restate what the question asks",
                "Check each option against the constraints",
                "Select the best-supported option",
            ],
            "score": 0,
            "score_reason": "Fallback: no solution methods could be generated",
        }